Conversation
klamike
left a comment
Awesome work! I have a few comments, nothing major
```toml
[deps]
BatchNLPKernels = "7145f916-0e30-4c9d-93a2-b32b6056125d"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
```
Do we need all of these? In particular CUDA, LuxCUDA, ExaModels?
test/runtests.jl
Outdated
```julia
    train_state_dual,
    data,
    stopping_criteria = [validation_testset],
)
```
src/L2OALM.jl
Outdated
```julia
Keywords:
- `max_dual`: Maximum value for the target dual variables.
"""
function LagrangianDualLoss(num_equal::Int; max_dual = 1e6)
```
Just a note, we should probably eventually have a somewhat standard interface for L2OMethods and their hyperparameters, i.e.

```julia
struct ALMMethod <: AbstractL2OMethod
    bm::BatchModel
    max_dual::Float64
    ρ_init::Float64
end
```

or

```julia
struct ALMMethod <: AbstractL2OMethod
    bm::BatchModel
    hyperparameters::Dict{Symbol,Any}
end
```

Ideally that also would help to clean up stuff like
Lines 196 to 197 in 0797bb7
```julia
mutable struct PrimalDualTrainer
    primal_model::Lux
    primal_training_state::
    dual_model::Lux
    dual_training_state::
    data::Dataloader
```
```julia
nvar = model.meta.nvar
ncon = model.meta.ncon
nθ = length(model.θ)
```
This is such a common thing, BNK should probably have a field for nθ, and expose a frontend like num_parameters, num_variables, num_constraints.
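For example, a minimal sketch of what that frontend could look like; the accessor names and the `bm.model` field are assumptions here, not the current BNK API:

```julia
import BatchNLPKernels as BNK

# Hypothetical BNK accessors, reusing the same lookups the code above does by hand.
num_variables(bm::BNK.BatchModel)   = bm.model.meta.nvar
num_constraints(bm::BNK.BatchModel) = bm.model.meta.ncon
num_parameters(bm::BNK.BatchModel)  = length(bm.model.θ)   # the nθ this comment asks for
```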
test/runtests.jl
Outdated
```julia
gh_bound = gh_test[1:end-num_equal, :]
gh_equal = gh_test[end-num_equal+1:end, :]
dual_hat_bound = dual_hat[1:end-num_equal, :]
dual_hat_equal = dual_hat[end-num_equal+1:end, :]
```
This is another obvious thing BNK should have -- functions that help you deal with indices
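As an illustration only, assuming the convention in the snippet above (the `num_equal` equality rows come last), the helpers could be as simple as:

```julia
# Hypothetical index helpers; names are illustrative, the split convention matches the test above.
bound_rows(gh, num_equal)    = @view gh[1:end-num_equal, :]
equality_rows(gh, num_equal) = @view gh[end-num_equal+1:end, :]

# usage: gh_bound = bound_rows(gh_test, num_equal)
```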
src/L2OALM.jl
Outdated
```julia
Dict{Symbol,Any}(
    :ρ => 1.0,
    :ρmax => 1e6,
    :τ => 0.8,
    :α => 10.0,
    :max_violation => 0.0,
),
```
Does Lux let you make these structs instead?
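For example, a sketch of the same hyperparameters as a plain typed struct (illustrative only, the struct name is made up and this isn't tied to any Lux API):

```julia
# Same fields and defaults as the Dict above, just as a concrete struct.
Base.@kwdef mutable struct ALMPrimalHyperparameters
    ρ::Float64 = 1.0
    ρmax::Float64 = 1e6
    τ::Float64 = 0.8
    α::Float64 = 10.0
    max_violation::Float64 = 0.0
end
```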
src/L2OALM.jl
Outdated
```julia
Default function that reconciles the state of the dual model after processing a batch of data.
This function computes the mean dual loss from the batch states.
"""
function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})
```
```diff
- function _reconcile_alm_dual_state(batch_states::Vector{NamedTuple})
+ function _reconcile_dual_state(batch_states::Vector{NamedTuple})
```
The `alm` can be removed since that is this whole repo 😄 (needs updates everywhere else, and for the primal version, `update_rho`, etc. too; let me know if you agree and I can add that commit)
src/L2OALM.jl
Outdated
```julia
function _default_dual_loop(num_equal::Int)
    return TrainingStepLoop(
        LagrangianDualLoss(num_equal),
        [(iter, current_state, hpm) -> iter >= 100 ? true : false],
        Dict{Symbol,Any}(:max_dual => 1e6, :ρ => 1.0),
        [],
        _reconcile_alm_dual_state,
        _pre_hook_dual,
    )
end
```
How about exposing the hyperparameters as kwargs here? Same for the primal one.
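For instance, a sketch of what that could look like for the dual loop, reusing the values currently hard-coded above (the keyword names are just suggestions):

```julia
# Illustrative only: the same defaults as above, exposed as keyword arguments.
function _default_dual_loop(num_equal::Int; max_dual = 1e6, ρ = 1.0, max_iter = 100)
    return TrainingStepLoop(
        LagrangianDualLoss(num_equal; max_dual = max_dual),
        [(iter, current_state, hpm) -> iter >= max_iter],
        Dict{Symbol,Any}(:max_dual => max_dual, :ρ => ρ),
        [],
        _reconcile_alm_dual_state,
        _pre_hook_dual,
    )
end
```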
src/L2OALM.jl
Outdated
```julia
stopping_criterion(
    iter_primal,
    current_state_primal,
    training_step_loop_primal.hyperparameters,
```
Is this some standard Lux API? Our stopping criteria don't need the state or the hyperparameters.
src/L2OALM.jl
Outdated
```julia
function _pre_hook_primal(
    θ,
    primal_model,
    train_state_primal,
    dual_model,
    train_state_dual,
    bm,
)
    # Forward pass for dual model
    dual_hat_k, _ = dual_model(θ, train_state_dual.parameters, train_state_dual.states)

    return (dual_hat_k,)
end
```
Is there some guidance for what to put in a "pre-hook" vs a "loss"? Do they get treated differently somehow?
Keep hooks, but move the primal and dual evaluation inside the loop with `ChainRulesCore.ignore_derivatives`.
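Roughly, that could look like the following sketch inside the loss; the function name and the `compute_loss` placeholder are hypothetical, and `ignore_derivatives` comes from ChainRulesCore:

```julia
using ChainRulesCore: ignore_derivatives

# Sketch: evaluate the dual network inside the primal loss but keep it out of the gradient,
# mirroring what `_pre_hook_primal` does above. `compute_loss` stands in for the rest of
# the primal loss computation.
function primal_loss_with_frozen_duals(θ, dual_model, train_state_dual, compute_loss)
    dual_hat_k = ignore_derivatives() do
        first(dual_model(θ, train_state_dual.parameters, train_state_dual.states))
    end
    return compute_loss(θ, dual_hat_k)
end
```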
src/L2OALM.jl
Outdated
```julia
mutable struct TrainingStepLoop
    loss_fn::Function
    stopping_criteria::Vector{Function}
    hyperparameters::Dict{Symbol,Any}
    parameter_update_fns::Vector{Function}
    reconcile_state::Function
    pre_hook::Function
end
```
Why `Vector{Function}` for `parameter_update_fns` and `stopping_criteria`? I think it only ever uses one.
I see that for the dual case there is no `parameter_update_fn`; I guess `(x...) -> nothing` can work there.
```julia
Θ_train = randn(T, nθ, dataset_size) |> dev_gpu
Θ_test = randn(T, nθ, dataset_size) |> dev_gpu

primal_model = feed_forward_builder(nθ, nvar, [320, 320])
```
Not sure where, but we should eventually have some magic for this... something like `L2ONN.feed_forward(bm, input=:all_params, output=:all_vars, hidden_sizes=[320, 320])`.
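Purely as a sketch of that interface; everything here is hypothetical (the `feed_forward` name, the keyword spec, and reading sizes off the model rather than the `BatchModel`), and it just wraps the `feed_forward_builder` already used in the test:

```julia
# Illustrative wrapper that infers the layer sizes instead of passing nθ/nvar by hand.
function feed_forward(model; input = :all_params, output = :all_vars, hidden_sizes = [320, 320])
    nin  = input  === :all_params ? length(model.θ) : error("unsupported input spec: $input")
    nout = output === :all_vars   ? model.meta.nvar : error("unsupported output spec: $output")
    return feed_forward_builder(nin, nout, hidden_sizes)
end

# primal_model = feed_forward(model)   # would replace the manual size bookkeeping above
```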
```julia
bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:full))
bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:full))
```
```diff
- bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:full))
- bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:full))
+ bm_train = BNK.BatchModel(model, batch_size, config = BNK.BatchModelConfig(:viol_grad))
+ bm_test = BNK.BatchModel(model, dataset_size, config = BNK.BatchModelConfig(:viol_grad))
```
`:viol_grad` suffices, to avoid jprod and Hessian storage.
Project.toml
Outdated
```toml
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[sources]
BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"}
```
```diff
- BatchNLPKernels = {url = "https://github.com/klamike/BatchNLPKernels.jl"}
+ BatchNLPKernels = {url = "https://github.com/LearningToOptimize/BatchNLPKernels.jl"}
```
Adds Augmented Lagrangian Primal-Dual Learning Method